#!/usr/bin/env python3
"""Compute evaluation metrics for GMO-Lennar output JSONs.

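Usage:
    python metric.py <results_dir_or_json> [--ctr_map MAP] [--out_file OUT]
    python metric.py budget <json_file_or_dir> [--ctr_map MAP] [--out_file OUT]

Example:
    python metric.py gmo_results --ctr_map dataset/ctr_week_7.json

In the default mode the script reads the given JSON file, or every JSON file
found recursively under the given directory, extracts `ground_truth` and a
prediction (`avg_predicted_score`, `predicted_score`, or another supported
prediction field) from each record, maps both to CTR values via the optional
mapping JSON (an identity 0-100 mapping is used when --ctr_map is omitted),
and then computes, per file and overall:
    • RMSE, R², MAPE
    • p10 / p20 / p30 (percentage-error thresholds)
    • 3-bucket (tertile) accuracy
    • 10-bucket (decile) accuracy

In `budget` mode it instead writes, per JSON file, the dataset-level MAPE for
a range of LLM-call budgets.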

When a mapping is applied, both the ground truth and the (possibly averaged,
float) prediction are rounded to the nearest integer percentile **before**
the lookup.
"""

import argparse
import glob
import json
import os
import re
from typing import List, Optional

import numpy as np
import pandas as pd
from sklearn.metrics import (
    mean_squared_error,
    r2_score,
    accuracy_score,
    mean_absolute_percentage_error,
)

# Default `num_llms` for calculate_budget_mape: outside its budget-3 special
# case, a global call budget B there buys max(1, B // NUM_LLMS) personas.
NUM_LLMS = 3  # Number of independent LLMs in the ensemble

# --------------------------------------------------
# Helpers
# --------------------------------------------------

def load_ctr_map(path: str) -> dict:
    """Return a mapping from percentile label (string) to CTR value (float).

    When *path* is ``None``, an empty string, or the literal ``"none"`` /
    ``"None"``, no file is read. Instead an identity mapping over the integer
    labels 0..100 is returned (``"0" -> 0.0`` … ``"100" -> 100.0``), so raw
    percentile predictions are compared directly against ground-truth labels.

    Otherwise *path* is parsed as a JSON object; every key is coerced to
    ``str`` and every value to ``float``.
    """
    disabled_markers = (None, "", "none", "None")
    if path not in disabled_markers:
        with open(path, "r", encoding="utf-8") as handle:
            raw_mapping = json.load(handle)
        return {str(label): float(ctr) for label, ctr in raw_mapping.items()}

    # No external map requested: identity over the integer percentiles 0-100.
    return dict((str(pct), float(pct)) for pct in range(101))


# ------------------------------------------------------------------
# Core collector – works with or without an external ctr_map.
# With a non-empty ctr_map (including the identity 0-100 map that
# load_ctr_map returns when no file is given), ground truth and prediction
# are rounded to integer percentiles and looked up; records whose labels are
# missing from the map are skipped. With no map (None or empty), the raw
# float values are used directly.
# ------------------------------------------------------------------

# Scalar prediction keys, checked in priority order (first non-None wins).
_SCALAR_PRED_KEYS = (
    "avg_predicted_score",      # legacy single-score key
    "predicted_score",          # legacy alias
    "mean_prediction",          # gpt_ctr_np per-record mean
    "overall_mean_prediction",  # gpt_ctr persona aggregate
)


def _mean_or_none(values):
    """Return the float mean of a non-empty list/tuple, or None if absent/non-numeric."""
    if not isinstance(values, (list, tuple)) or not values:
        return None
    try:
        return float(np.mean(values))
    except (TypeError, ValueError):
        return None


def _extract_prediction(rec: dict):
    """Resolve a record's predicted percentile, or return None if none is found.

    Priority (first non-None wins):
      1-4. the scalar keys in ``_SCALAR_PRED_KEYS``;
      5.   the mean of the flat ``"predictions"`` list;
      6.   the mean over personas (``"personas"`` or ``"persona_predictions"``)
           of each persona's ``"mean_prediction"``, falling back to the mean
           of that persona's ``"predictions"`` list.
    A scalar value is returned as-is (it may still fail float conversion).
    """
    for key in _SCALAR_PRED_KEYS:
        value = rec.get(key)
        if value is not None:
            return value

    flat_mean = _mean_or_none(rec.get("predictions"))
    if flat_mean is not None:
        return flat_mean

    personas = rec.get("personas") or rec.get("persona_predictions")
    if not isinstance(personas, dict):
        return None

    persona_means = []
    for p_info in personas.values():
        if not isinstance(p_info, dict):
            continue
        mp = p_info.get("mean_prediction")
        if mp is None:
            mp = _mean_or_none(p_info.get("predictions"))
        if mp is None:
            continue
        try:
            persona_means.append(float(mp))
        except (TypeError, ValueError):
            pass  # non-numeric persona mean: ignore this persona only
    return float(np.mean(persona_means)) if persona_means else None


def _resolve_ctr_key(ctr_map: dict, label: int):
    """Return the key under which *label* is stored in *ctr_map* ("N" or "N.0"), else None."""
    for candidate in (str(label), f"{label}.0"):
        if candidate in ctr_map:
            return candidate
    return None


def collect_results(result_files: List[str], ctr_map: Optional[dict] = None):
    """Collect (ground truth, prediction) pairs from result JSON files.

    Each file is expected to hold a JSON list of record dicts; unreadable or
    malformed files, non-list payloads and non-dict records are skipped. The
    ground truth is ``rec["ground_truth"]``; the prediction is resolved by
    ``_extract_prediction``. Records where either value is missing,
    non-numeric or non-finite (NaN/inf) are skipped.

    Args:
        result_files: paths of JSON files, processed in order.
        ctr_map: optional mapping from percentile label ("N" or "N.0") to CTR.
            When non-empty, both values are rounded to integer labels and
            looked up, and records with an unmapped label are skipped. When
            None or empty, the raw float values are used.

    Returns:
        DataFrame with columns ``ctr_y`` / ``socia_pred`` (mapped or raw
        values) and ``pct_y`` / ``pct_pred`` (rounded integer labels).
    """
    gt_vals, pred_vals = [], []          # mapped CTR values (or raw floats)
    gt_pct_vals, pred_pct_vals = [], []  # integer percentile labels

    for fp in result_files:
        try:
            with open(fp, "r", encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError):
            continue  # best-effort: skip unreadable / malformed files
        if not isinstance(data, list):
            continue  # e.g. a mapping JSON found by a recursive glob

        for rec in data:
            if not isinstance(rec, dict):
                continue

            gt_raw = rec.get("ground_truth")
            pred_raw = _extract_prediction(rec)
            if gt_raw is None or pred_raw is None:
                continue

            try:
                gt_float = float(gt_raw)
                pr_float = float(pred_raw)
            except (TypeError, ValueError, OverflowError):
                continue
            # NaN/inf cannot be rounded to an integer label and would poison
            # every downstream metric.
            if not (np.isfinite(gt_float) and np.isfinite(pr_float)):
                continue

            gt_int = int(round(gt_float))
            pr_int = int(round(pr_float))

            if ctr_map:
                gt_key = _resolve_ctr_key(ctr_map, gt_int)
                pr_key = _resolve_ctr_key(ctr_map, pr_int)
                if gt_key is None or pr_key is None:
                    continue  # mapping failed for either side
                gt_vals.append(ctr_map[gt_key])
                pred_vals.append(ctr_map[pr_key])
            else:
                gt_vals.append(gt_float)
                pred_vals.append(pr_float)

            gt_pct_vals.append(gt_int)
            pred_pct_vals.append(pr_int)

    return pd.DataFrame({
        "ctr_y": gt_vals,          # continuous CTR values
        "socia_pred": pred_vals,   # continuous CTR predictions
        "pct_y": gt_pct_vals,      # integer percentiles (ground truth)
        "pct_pred": pred_pct_vals  # integer percentiles (prediction)
    })


def compute_bucket_accuracy(df: pd.DataFrame):
    """Return ``(acc3, acc10)``: tertile and decile bucket accuracy.

    Bucket edges are quantiles of the ground truth (1/3 and 2/3 for tertiles,
    0.1 … 0.9 for deciles). Ground truth and prediction are both binned with
    the same edges via ``np.digitize``; accuracy is the fraction of rows whose
    buckets match.

    Uses the integer percentile columns ``pct_y`` / ``pct_pred`` when both are
    present, otherwise falls back to the continuous ``ctr_y`` / ``socia_pred``
    columns. Returns ``(nan, nan)`` for an empty frame, where quantiles (and
    therefore bucket edges) are undefined.
    """
    if "pct_y" in df and "pct_pred" in df:
        gt, pred = df["pct_y"], df["pct_pred"]
    else:
        gt, pred = df["ctr_y"], df["socia_pred"]

    if len(gt) == 0:
        return float("nan"), float("nan")

    def _bucket_accuracy(probs) -> float:
        # Open-ended outer edges so every value falls into some bucket.
        edges = [-np.inf] + gt.quantile(probs).tolist() + [np.inf]
        gt_buckets = np.digitize(gt, bins=edges)
        pred_buckets = np.digitize(pred, bins=edges)
        return float(np.mean(gt_buckets == pred_buckets))

    acc3 = _bucket_accuracy([1 / 3, 2 / 3])
    # Integer steps divided by 10 give correctly rounded 0.1 … 0.9; a float-step
    # np.arange(0.1, 1.0, 0.1) yields 0.30000000000000004, nudging that edge
    # just above an integer label and binning ties inconsistently.
    acc10 = _bucket_accuracy(np.arange(1, 10) / 10)
    return acc3, acc10


# --------------------------------------------------
# Budget-based MAPE evaluation (similar to webaes version)
# --------------------------------------------------

def calculate_budget_mape(json_file_path: str, output_file_path: str, ctr_map: Optional[dict] = None, max_budget: int = 100, step: int = 10, num_llms: int = NUM_LLMS):
    r"""Compute dataset-level MAPE (%) for call budgets 0, 3, step, 2*step, …, max_budget.

    A global budget ``B`` is first converted to a *persona* budget
    :math:`P = \max(1, \lfloor B / \text{num\_llms} \rfloor)`.

    • Persona schema (``"personas"`` or ``"persona_predictions"`` dict) → use
      the **first** ``P`` personas, taking only each one's first prediction
      (1 call each). Budget 3 is special-cased to the first 3 personas.
    • Flat schema (``"predictions"`` list) → use the first ``P`` predictions.

    Per record and budget, the selected finite numeric predictions are mapped
    through ``ctr_map`` and averaged. MAPE is computed over records that have
    a prediction and a non-zero ground truth. Fewer than two predicted records
    gives NaN, and budget 0 is defined as 0.0. The metric plateaus once all
    available personas (or predictions) have been selected. Records that are
    not dicts, or whose ground truth is missing, non-numeric or non-finite,
    are ignored.

    One line per budget (``"%.2f"`` or ``"NaN"``), in ascending budget order,
    is written to *output_file_path*. The same values are returned as
    ``{budget: mape}``.
    """

    # Budget 3 is always included; sorting keeps the output in ascending
    # budget order for any ``step``.
    budgets = sorted({0, 3, *range(step, max_budget + 1, step)})
    positive_budgets = [b for b in budgets if b > 0]
    predictions_by_budget = {b: [] for b in positive_budgets}
    ground_truths: List[float] = []

    def map_ctr(val: float) -> float:
        """Map a percentile to CTR via ``ctr_map`` ("N" or "N.0" key); else return *val*."""
        if ctr_map is None:
            return val
        key_int = int(round(val))
        for key in (str(key_int), f"{key_int}.0"):
            if key in ctr_map:
                return ctr_map[key]
        # NOTE(review): falling back to the raw percentile mixes scales when a
        # real CTR map is supplied (collect_results skips such records instead).
        return val

    def mean_ctr(values) -> float:
        """Average the mapped CTR of the finite numeric *values*; NaN if none remain."""
        mapped = []
        for v in values:
            try:
                fv = float(v)
            except (TypeError, ValueError, OverflowError):
                continue
            if np.isfinite(fv):  # NaN/inf cannot be rounded for the map lookup
                mapped.append(map_ctr(fv))
        return float(np.mean(mapped)) if mapped else np.nan

    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Only a JSON list holds result records; any other payload yields none.
    records = data if isinstance(data, list) else []

    for rec in records:
        if not isinstance(rec, dict):
            continue
        gt_raw = rec.get("ground_truth")
        if gt_raw is None:
            continue
        try:
            gt_float = float(gt_raw)
        except (TypeError, ValueError, OverflowError):
            continue
        if not np.isfinite(gt_float):
            continue  # would crash the map lookup and poison MAPE
        ground_truths.append(map_ctr(gt_float))

        # Every accepted record appends exactly one entry per budget so that
        # predictions stay index-aligned with ground_truths.
        persona_dict = rec.get("personas") or rec.get("persona_predictions")
        if isinstance(persona_dict, dict) and persona_dict:
            # Each persona's prediction list, in the dict's insertion order.
            ordered_preds = []
            for pinfo in persona_dict.values():
                plist = pinfo.get("predictions") if isinstance(pinfo, dict) else None
                ordered_preds.append(plist if isinstance(plist, list) else [])

            for b in positive_budgets:
                # Budget 3 takes the first 3 personas; otherwise the budget buys
                # floor(b / num_llms) personas (at least one).
                n_personas = 3 if b == 3 else max(1, b // num_llms)
                firsts = [preds[0] for preds in ordered_preds[:n_personas] if preds]
                predictions_by_budget[b].append(mean_ctr(firsts))
        else:
            flat_preds = rec.get("predictions")
            has_flat = isinstance(flat_preds, list) and len(flat_preds) > 0
            for b in positive_budgets:
                if has_flat:
                    p = max(1, b // num_llms)
                    predictions_by_budget[b].append(mean_ctr(flat_preds[:p]))
                else:
                    predictions_by_budget[b].append(np.nan)

    # --- Compute MAPE for each budget ---
    y_true = np.array(ground_truths, dtype=float)
    budget_results = {0: 0.0}  # with zero calls we define MAPE = 0 (baseline)

    for b in positive_budgets:
        preds = np.array(predictions_by_budget[b], dtype=float)
        mask = ~np.isnan(preds)
        if mask.sum() < 2:
            budget_results[b] = np.nan
            continue
        gt_masked = y_true[mask]
        pr_masked = preds[mask]
        non_zero = gt_masked != 0  # avoid division by zero ground truth
        if not non_zero.any():
            budget_results[b] = np.nan
            continue
        budget_results[b] = mean_absolute_percentage_error(gt_masked[non_zero], pr_masked[non_zero]) * 100

    # Write only the MAPE values, one per line in budget order
    with open(output_file_path, "w", encoding="utf-8") as f_out:
        for b in budgets:
            val = budget_results[b]
            f_out.write("NaN\n" if np.isnan(val) else f"{val:.2f}\n")

    return budget_results

# --------------------------------------------------
# Entry-point
# --------------------------------------------------

def _default_out_path(input_path: str, suffix: str) -> str:
    """Return *input_path* with its extension replaced by *suffix*.

    ``os.path.splitext`` guarantees the result differs from the input, so the
    input file is never overwritten (even without a ``.json`` extension), and
    ``.json`` occurring elsewhere in the path is left untouched.
    """
    root, _ext = os.path.splitext(input_path)
    return root + suffix


def _metric_line(name: str, df: pd.DataFrame) -> str:
    """Return one tab-separated metrics row for *df*.

    Fields: N, RMSE, MAPE (%, rows with zero ground truth excluded), R²,
    p10/p20/p30 (share of rows whose absolute percentage error is below
    10/20/30 %), and tertile/decile bucket accuracy.
    """
    y_t = df["ctr_y"]
    y_p = df["socia_pred"]
    rm = np.sqrt(mean_squared_error(y_t, y_p))
    r2v = r2_score(y_t, y_p)

    # MAPE excludes rows where y_t == 0 (percentage error undefined there).
    non_zero_mask = y_t != 0
    if non_zero_mask.any():
        mape_val = mean_absolute_percentage_error(y_t[non_zero_mask], y_p[non_zero_mask]) * 100
    else:
        mape_val = float('nan')

    # Element-wise percentage errors for the p10 / p20 / p30 bands; zero ground
    # truth yields inf/NaN, which is dropped.
    pe = 100 * np.abs(y_t - y_p) / y_t
    df_pe = pe.replace([np.inf, -np.inf], np.nan).dropna().to_frame("ctr")
    p10v = (df_pe["ctr"] < 10).mean()
    p20v = (df_pe["ctr"] < 20).mean()
    p30v = (df_pe["ctr"] < 30).mean()
    a3, a10 = compute_bucket_accuracy(df)
    return f"{name}\tN:{len(df)}\tRMSE:{rm:.4f}\tMAPE:{mape_val:.2f}%\tR2:{r2v:.4f}\tp10:{p10v:.3f}\tp20:{p20v:.3f}\tp30:{p30v:.3f}\tAcc3:{a3:.3f}\tAcc10:{a10:.3f}"


def _run_budget_mode(inp: str, out_file: Optional[str], ctr_map: dict) -> None:
    """Write budget-based MAPE curves for one JSON file or every JSON under a directory."""
    if not os.path.isdir(inp):
        if not os.path.exists(inp):
            print("[ERROR] File not found:", inp)
            return
        out_path = out_file or _default_out_path(inp, "_budget_metrics.txt")
        calculate_budget_mape(inp, out_path, ctr_map)
        print(f"Budget metrics → {out_path}")
        return

    json_files = sorted(glob.glob(os.path.join(inp, "**", "*.json"), recursive=True))
    if not json_files:
        print("[ERROR] No JSON files found in", inp)
        return
    if out_file and len(json_files) > 1:
        # A single --out_file would be overwritten by every input in turn.
        print("[WARNING] --out_file ignored for multiple inputs; writing one file per input.")
        out_file = None

    for jf in json_files:
        out_path = out_file or _default_out_path(jf, "_budget_metrics.txt")
        try:
            calculate_budget_mape(jf, out_path, ctr_map)
        except (OSError, ValueError) as exc:
            # Best-effort over a directory: one bad file must not abort the rest.
            print(f"[WARNING] Skipping {jf}: {exc}")
            continue
        print(f"Budget metrics → {out_path}")


def _run_standard_mode(input_path: str, out_file: Optional[str], ctr_map: dict) -> None:
    """Print per-file and overall metrics, then write them to a text file."""
    if os.path.isdir(input_path):
        result_files = sorted(glob.glob(os.path.join(input_path, "**", "*.json"), recursive=True))
    else:
        if not os.path.exists(input_path):
            print("[ERROR] File not found:", input_path)
            return
        result_files = [input_path]

    if not result_files:
        print("[ERROR] No JSON result files found in", input_path)
        return

    # Parse each file once; collect_results has no cross-file state, so the
    # overall frame is just the concatenation of the per-file frames.
    per_file = [(fp, collect_results([fp], ctr_map)) for fp in result_files]
    per_file = [(fp, df_f) for fp, df_f in per_file if not df_f.empty]
    if not per_file:
        print("[ERROR] No valid records after mapping – check files/mapping.")
        return
    df_all = pd.concat([df_f for _, df_f in per_file], ignore_index=True)

    out_lines: List[str] = [_metric_line(os.path.basename(fp), df_f) for fp, df_f in per_file]
    out_lines.append("-" * 80)
    out_lines.append(_metric_line("OVERALL", df_all))

    print("\n".join(out_lines))

    if out_file:
        out_path = out_file
    elif os.path.isdir(input_path):
        out_path = os.path.join(input_path, "metrics.txt")
    else:
        out_path = _default_out_path(input_path, "_metrics.txt")
    try:
        with open(out_path, "w", encoding="utf-8") as f_out:
            f_out.write("\n".join(out_lines))
        print(f"\nMetrics written to {out_path}")
    except OSError as e:
        print("[WARNING] Could not write metrics file:", e)


def main():
    """Command-line entry point.

    Two modes:
      1. Standard (default) – full metrics per file plus an OVERALL line.
      2. ``budget``         – only the budget-based MAPE curves.
    """
    import sys

    if len(sys.argv) < 2:
        print("Usage:\n  python metric_ensemble_ctr.py <results_dir> [--ctr_map MAP] [--out_file OUT]\n  python metric_ensemble_ctr.py budget <json_file_or_dir> [--ctr_map MAP] [--out_file OUT]")
        return

    budget_only = sys.argv[1].lower() == "budget"
    if budget_only:
        sys.argv.pop(1)  # remove the subcommand so argparse sees remaining args

    ap = argparse.ArgumentParser()
    ap.add_argument("results_path", help="Path to JSON file or directory (depends on mode)")
    ap.add_argument("--ctr_map", default=None, help="Optional percentile→CTR mapping JSON")
    ap.add_argument("--out_file", default=None, help="Output file (defaults based on input path)")
    args = ap.parse_args()

    ctr_map = load_ctr_map(args.ctr_map)

    if budget_only:
        _run_budget_mode(args.results_path, args.out_file, ctr_map)
    else:
        _run_standard_mode(args.results_path, args.out_file, ctr_map)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main() 